import tensorflow as tf
import cv2
import numpy as np

# Interactive session shared by the rest of the script (used for restoring
# the checkpoint and evaluating tensors).
sess = tf.InteractiveSession()

# Graph inputs: flattened images with 784 values per row, and label rows
# with 10 values each. The batch dimension is left unspecified.
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

def conv2d(name, x, W, b):
    """Apply a stride-1, SAME-padded 2-D convolution, add a bias, then ReLU.

    Args:
        name: Name given to the final ReLU op.
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        b: Bias tensor added via ``tf.nn.bias_add``.

    Returns:
        The ReLU-activated output tensor.
    """
    convolved = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
    with_bias = tf.nn.bias_add(convolved, b)
    return tf.nn.relu(with_bias, name=name)

def max_pool2x2(name, x):
    """Max-pool ``x`` with a 2x2 window and stride 2, using SAME padding.

    Args:
        name: Name given to the pooling op.
        x: Input tensor passed to ``tf.nn.max_pool``.

    Returns:
        The pooled tensor.
    """
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME', name=name)

def norm(name, x):
    """Apply local response normalization to ``x``.

    Uses a depth radius of 4, bias 1.0, alpha 0.01/9.0 and beta 0.75.

    Args:
        name: Name given to the normalization op.
        x: Input tensor passed to ``tf.nn.lrn``.

    Returns:
        The normalized tensor.
    """
    lrn_alpha = 0.01 / 9.0
    return tf.nn.lrn(x, depth_radius=4, bias=1.0, alpha=lrn_alpha, beta=0.75, name=name)

def alex_net(x, Weights, Biases):
    """Build the forward pass of the network and return its output tensor.

    The input is reshaped to ``[-1, 28, 28, 1]`` and passed through three
    convolution + 2x2 max-pool stages, two ReLU fully connected layers and
    a final linear layer. No normalization or dropout is applied.

    Args:
        x: Input tensor; must be reshapeable to ``[-1, 28, 28, 1]``.
        Weights: Mapping with keys 'wc1', 'wc2', 'wc3', 'wd1', 'wd2', 'out'.
        Biases: Mapping with keys 'bc1', 'bc2', 'bc3', 'bd1', 'bd2', 'out'.

    Returns:
        The output of the final linear layer (no activation applied).
    """
    # (conv op name, pool op name, weight key, bias key) for each stage.
    conv_stages = (
        ('conv1', 'pool1', 'wc1', 'bc1'),
        ('conv2', 'pool2', 'wc2', 'bc2'),
        ('conv3', 'pool3', 'wc3', 'bc3'),
    )

    features = tf.reshape(x, shape=[-1, 28, 28, 1])
    for conv_name, pool_name, w_key, b_key in conv_stages:
        features = conv2d(conv_name, features, Weights[w_key], Biases[b_key])
        features = max_pool2x2(pool_name, features)

    # Flatten to the input width expected by the first dense weight matrix.
    flat_width = Weights['wd1'].get_shape().as_list()[0]
    flattened = tf.reshape(features, [-1, flat_width])

    # Fully connected layer 1.
    hidden1 = tf.nn.relu(tf.matmul(flattened, Weights['wd1']) + Biases['bd1'], name='fc1')

    # Fully connected layer 2.
    hidden2 = tf.nn.relu(tf.matmul(hidden1, Weights['wd2']) + Biases['bd2'], name='fc2')

    # Output layer: linear, no activation.
    return tf.matmul(hidden2, Weights['out']) + Biases['out']

# Trainable parameters of the network, keyed by the names alex_net() looks up.
#
# NOTE(review): none of these variables is given an explicit name, so the
# checkpoint restore presumably matches them by TensorFlow's auto-generated
# names, which follow creation order -- confirm before reordering or
# inserting entries.

# Convolution filters are 3x3 with channel depths 1 -> 64 -> 128 -> 256.
# 'wd1' expects 4 * 4 * 256 flattened inputs: alex_net() reshapes to 28x28 and
# applies three stride-2 SAME poolings (28 -> 14 -> 7 -> 4).
# Initial values are drawn from a truncated normal distribution.
weights = {
    'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], dtype=tf.float32)),
    'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32)),
    'wc3': tf.Variable(tf.truncated_normal([3, 3, 128, 256], dtype=tf.float32)),
    'wd1': tf.Variable(tf.truncated_normal([4 * 4 * 256, 1024], dtype=tf.float32)),
    'wd2': tf.Variable(tf.truncated_normal([1024, 1024], dtype=tf.float32)),
    'out': tf.Variable(tf.truncated_normal([1024, 10], dtype=tf.float32))
}
# One bias vector per layer, sized to that layer's output width; initial
# values are drawn from a normal distribution.
biases = {
    'bc1': tf.Variable(tf.random_normal([64]), dtype=tf.float32),
    'bc2': tf.Variable(tf.random_normal([128]), dtype=tf.float32),
    'bc3': tf.Variable(tf.random_normal([256]), dtype=tf.float32),
    'bd1': tf.Variable(tf.random_normal([1024]), dtype=tf.float32),
    'bd2': tf.Variable(tf.random_normal([1024]), dtype=tf.float32),
    'out': tf.Variable(tf.random_normal([10]), dtype=tf.float32)
}

# Created after the variables above; used below to restore the checkpoint.
saver = tf.train.Saver()

if __name__ == '__main__':
    # Evaluate the restored model on data/data2/<digit>.<sample>.jpg and
    # print per-image predictions followed by the overall accuracy.
    saver.restore(sess, "model/AlexNet.ckpt")
    print("read model successfully")

    # Build the inference graph exactly once on the module-level placeholder.
    # Calling alex_net() per image would add a fresh copy of every op to the
    # graph on each iteration, growing memory and slowing each evaluation.
    logits = alex_net(x, weights, biases)

    samples_per_digit = 8
    num_digits = 10
    count = 0
    for j in range(samples_per_digit):
        for i in range(num_digits):
            image_path = 'data/data2/%s.%s.jpg' % (i, j + 1)
            img = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise IOError('could not read image: %s' % image_path)
            # cv2.imread returns channels in BGR order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = cv2.GaussianBlur(img, (3, 3), 0)
            # The image must already be 28x28; reshape raises otherwise.
            # NOTE(review): img is presumably uint8 here, in which case the
            # multiplication wraps modulo 256 before the float cast. Kept
            # unchanged because altering it changes the network input --
            # confirm the intended scaling.
            im_data = np.array(np.reshape(img, [28, 28]) * 255, dtype=np.float32)
            output = sess.run(logits, feed_dict={x: im_data.reshape(1, 784)})[0]
            predicted = int(np.argmax(output))
            print(predicted, ' and ', i)
            if predicted == i:
                count += 1
    res = count / float(samples_per_digit * num_digits)
    print('accuracy is %.1f%%' % (res * 100))